from BasicLibraries import *
from matplotlib import rcParams
from collections import defaultdict

'''
Auxiliary plotting helpers: build initiative-by-SDG matrices from a
DataFrame and draw them as annotated heatmaps.
'''


# Module-wide matplotlib defaults, applied once at import time:
# the major x-tick padding is set to '8' and the font family to serif.
# NOTE(review): 'font.sans-serif' is also set to Verdana, which presumably
# only matters if the family is later switched to sans-serif -- confirm.
plt.rcParams.update({'xtick.major.pad': '8'})
rcParams.update({
    'font.family': 'serif',
    'font.sans-serif': ['Verdana'],
})
#%%

class MatrixPlot():
    '''
    Build initiative-by-SDG matrices from a wide DataFrame and draw them
    as annotated seaborn heatmaps.

    The input data is expected to hold one column per
    "<initiative> - SDG <k>" pair, plus a year column and a company
    identifier column.
    '''

    # Initiative labels (and their row order) used when the caller does
    # not pass an explicit initiatives_list.
    _DEFAULT_INITIATIVES = (
        'communication', 'donation & funding', 'association', 'pricing',
        'adoption of standards and rules', 'assessment and measurement',
        'organizational structuring', 'modification of procedures',
        'asset modification', 'training', 'incentives', 'r&d investments',
        'new products', 'volunteerism',
    )

    def __init__(self):
        pass

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _no_axes(ax):
        '''Return True when the caller did not supply axes ('' or None).'''
        return ax is None or (isinstance(ax, str) and ax == '')

    @staticmethod
    def _attach_margin(X, label, row_margin, col_margin, corner):
        '''
        Return a copy of X with an extra column and an extra row, both
        named `label`, holding the row/column margins; `corner` is written
        in the cell where the two meet.
        '''
        out = X.copy()
        out[label] = row_margin
        # DataFrame.append was removed in pandas 2.0: build the margin as
        # a one-row frame and concatenate it instead.
        margin_row = col_margin.rename(label).to_frame().T
        out = pd.concat([out, margin_row], sort=False)
        out.loc[label, label] = corner
        return out

    @staticmethod
    def _sort_margin_last(X, label):
        '''
        Sort rows and columns by the `label` margin (descending) and move
        the margin row/column to the end.
        '''
        X = X.sort_values(by=label, axis=0, ascending=False)
        X = X.sort_values(by=label, axis=1, ascending=False)
        # Select by label rather than by position: the margin is not
        # guaranteed to come first after sorting (ties, negative values).
        rows = [r for r in X.index if r != label]
        cols = [c for c in X.columns if c != label]
        return X.loc[rows + [label], cols + [label]]

    @staticmethod
    def _draw_heatmap(X, ax, vmin, vmax, center, annot_size):
        '''
        Draw X as an annotated heatmap with y ticks on the right and
        return the axes used. With ax=None seaborn draws on the current
        axes.
        '''
        ax = sns.heatmap(X, annot=True, cbar=False, fmt='.6g',
                         vmin=vmin, vmax=vmax,
                         center=center, linewidths=0.8,
                         annot_kws={"color": "k", 'size': annot_size},
                         linecolor='black', alpha=0.6,
                         cmap='RdYlBu', ax=ax)
        ax.yaxis.tick_right()
        return ax

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def get_InitiativesMatrix(self, dat, sdgs_list=None,
                              initiatives_list=None,
                              year_label='rfyear',
                              company_id='gvkey',
                              order='year_name'):
        '''
        Given the whole data as input return the initiative-SDG matrix
        by year and firm in a dictionary.

        Parameters
        ----------
        dat : DataFrame with one column per "<initiative> - SDG <k>" pair
            plus the `year_label` and `company_id` columns.
        sdgs_list : SDG identifiers to use; empty/None means 1..17.
        initiatives_list : initiative labels to use; empty/None means the
            built-in default list.
        year_label, company_id : names of the year and company columns.
        order : 'year_name' nests the result as [year][company]; any other
            value nests it as [company][year].

        Returns
        -------
        (initiative_matrix, initiative_vector) : two defaultdict(dict).
            The first maps to an initiatives x SDGs DataFrame per row of
            `dat`, the second to the flat Series of the same values.
            Year keys are str(year) as read from the row.

        Raises
        ------
        ValueError if the selected frame does not have exactly
        len(sdgs) * len(initiatives) + 2 columns.
        '''
        if initiatives_list is None or len(initiatives_list) == 0:
            ordered_initiatives = list(MatrixPlot._DEFAULT_INITIATIVES)
        else:
            ordered_initiatives = list(initiatives_list)
        if sdgs_list is None or len(sdgs_list) == 0:
            sdgs = ['SDG ' + str(i) for i in range(1, 18)]
        else:
            sdgs = ['SDG ' + str(i) for i in sdgs_list]
        # Initiative-major order, matching the reshape below.
        value_cols = [ini + ' - ' + s for ini in ordered_initiatives for s in sdgs]

        initiative_matrix, initiative_vector = defaultdict(dict), defaultdict(dict)
        dat = dat.sort_values(by=year_label)
        x = dat[value_cols + [year_label, company_id]]
        if len(x.columns) != len(sdgs) * len(ordered_initiatives) + 2:
            raise ValueError(
                "Problem with the data formatting. Control that there are "
                + str(len(sdgs)) + " SDG and "
                + str(len(ordered_initiatives)) + " Initiatives")

        n_rows, n_cols = len(ordered_initiatives), len(sdgs)
        by_year = order == 'year_name'
        # Rows are read one by one with .iloc so that the year/company
        # values (and hence the dictionary keys) are exactly the ones the
        # row Series exposes.
        for pos in range(len(x)):
            row = x.iloc[pos]
            values = row.iloc[:-2]
            M = pd.DataFrame(np.array(values).reshape(n_rows, n_cols),
                             index=ordered_initiatives, columns=sdgs)
            year_key = str(row.iloc[-2])
            firm = row.iloc[-1]
            outer, inner = (year_key, firm) if by_year else (firm, year_key)
            initiative_matrix[outer][inner] = M
            initiative_vector[outer][inner] = values

        return initiative_matrix, initiative_vector

    def MakeMat(self, DT, ax='', sdgs_list=None, initiatives_list=None, vmn=0, vmx=1000, rounded=0, percentage=True, show_plot=True, sdg='all', dpi_=100):
        '''
        This is the visualization of the environmental matrix.

        Sums the initiative-SDG matrices of every row of DT, optionally
        converts the result to percentages of the grand total, appends
        'Total' margins, sorts rows/columns by those margins (descending,
        'Total' last) and optionally draws a heatmap.

        Parameters
        ----------
        DT : input DataFrame (see get_InitiativesMatrix); it is copied.
        ax : axes to draw on; '' or None creates a new figure.
        sdgs_list, initiatives_list : forwarded to get_InitiativesMatrix.
        vmn, vmx : colour scale limits passed to the heatmap.
        rounded : decimals kept in the returned matrix.
        percentage : express cells as percentage of the grand total.
        show_plot : draw the heatmap when True.
        sdg : 'all' keeps every available SDG column, otherwise an
            iterable of SDG identifiers to keep.
        dpi_ : dpi of the figure created when no axes are supplied.

        Returns
        -------
        DataFrame with the aggregated matrix and its 'Total' margins.

        Raises
        ------
        ValueError if DT has no rows.
        '''
        matrix_dict, _ = self.get_InitiativesMatrix(DT.copy(), sdgs_list, initiatives_list)
        matrices = [m for per_key in matrix_dict.values() for m in per_key.values()]
        if not matrices:
            raise ValueError("No observations in the data: nothing to aggregate.")
        template = matrices[-1]
        summed = np.sum([m.values for m in matrices], axis=0)
        X = pd.DataFrame(summed, columns=template.columns, index=template.index).astype(int)
        #==============================
        use_all = isinstance(sdg, str) and sdg == 'all'
        if use_all:
            # Canonical SDG order first; only columns that actually exist
            # are selected so a custom sdgs_list does not raise KeyError.
            canonical = ['SDG ' + str(k) for k in range(1, 18)]
            present = [c for c in canonical if c in X.columns]
            extra = [c for c in X.columns if c not in canonical]
            X = X[present + extra]
            fgsz = (24, 14)
        else:
            X = X[['SDG ' + str(S) for S in sdg]]
            fgsz = (14, 14)
        if percentage:
            X = (X / X.sum().sum()) * 100
        X = X.round(2)

        X = MatrixPlot._attach_margin(X, 'Total',
                                      X.sum(axis=1).round(),
                                      X.sum().round(),
                                      np.round(X.sum().sum()))
        X = np.round(X.astype(float), rounded)
        X = MatrixPlot._sort_margin_last(X, 'Total')

        if show_plot:
            own_axes = MatrixPlot._no_axes(ax)
            if own_axes:
                # Only open a figure when we own the axes; otherwise the
                # caller's axes would stop being the current ones.
                plt.figure(figsize=fgsz, dpi=dpi_)
            ax = MatrixPlot._draw_heatmap(X, None if own_axes else ax,
                                          vmn, vmx, 0, '24')
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=24)
            ax.set_xticklabels(ax.get_xticklabels(), rotation=0, fontsize=24)
            ax.set_ylim(len(X) + 1, -1)
            xmin, xmax = ax.get_xlim()
            # NOTE(review): the separator positions below are hard-coded
            # and presumably assume 14 initiatives and 17 (or 7) SDGs.
            if use_all:
                ax.hlines([5, 10, 14], xmin, xmax, color=LCBlue, linewidths=3)
                ax.vlines([6, 12, 18], ymin=0, ymax=15, color=LCBlue, linewidths=3)
            else:
                ax.hlines([14], xmin, xmax, color='navy', linewidths=3)
                ax.vlines([7], ymin=0, ymax=15, color='navy', linewidths=3)
        return X

    def SimpleMatrixPlot(self, X, ax='', rounded=1, dpi_=100, significance=False, plotTotal=True, plot_mat=True,
                         vmn=-150, vmx=300, ctr=0):
        '''
        Append 'Total' (sum) or 'Average' (mean) margins to a matrix, sort
        rows/columns by them (descending, margin last) and optionally draw
        the result as a heatmap titled 'Behavioral differences'.

        Parameters
        ----------
        X : DataFrame to plot. An existing 'Total' row/column is dropped
            and recomputed. The caller's frame is never modified.
        ax : axes to draw on; '' or None creates a new figure.
        rounded, significance : accepted but currently unused.
        dpi_ : dpi of the figure created when no axes are supplied.
        plotTotal : True for sum margins named 'Total', False for mean
            margins named 'Average'.
        plot_mat : draw the heatmap when True.
        vmn, vmx, ctr : colour scale limits and centre for the heatmap.

        Returns
        -------
        DataFrame with the margins appended.
        '''
        if 'Total' in list(X.columns):
            X = X.drop(columns=['Total'], index=['Total'])
        else:
            # Work on a copy so the margin column is not written into the
            # caller's DataFrame.
            X = X.copy()

        if plotTotal:
            label = 'Total'
            col_stat, row_stat = X.sum(), X.sum(axis=1)
            corner = col_stat.sum()
        else:
            label = 'Average'
            col_stat, row_stat = X.mean(), X.mean(axis=1)
            corner = col_stat.mean()
        X = MatrixPlot._attach_margin(X, label,
                                      np.round(row_stat, 1),
                                      np.round(col_stat, 1),
                                      np.round(corner, 0))
        X = MatrixPlot._sort_margin_last(X, label)

        if plot_mat:
            own_axes = MatrixPlot._no_axes(ax)
            if own_axes:
                plt.figure(figsize=(14, 12), dpi=dpi_)
            ax = MatrixPlot._draw_heatmap(X, None if own_axes else ax,
                                          vmn, vmx, ctr, '21')
            # Title and limits are set on the axes actually drawn on, not
            # on whatever pyplot considers current.
            ax.set_title('Behavioral differences', fontsize=28)
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
            ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
            ax.set_ylim(len(X) + 1, -1)
            xmin, xmax = ax.get_xlim()
            # NOTE(review): hard-coded separators, presumably for a
            # 14-initiative x 7-SDG matrix -- confirm with callers.
            ax.hlines([14], xmin, xmax, color='navy', linewidths=3)
            ax.vlines([7], ymin=0, ymax=15, color='navy', linewidths=3)
        return X